/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*!
 * \file sort_common_impl.h
 * \brief Common Sort implementation shared across chip versions; dispatches to
 *        the version-specific headers (v100/v200/v220/v300) selected by __CCE_AICORE__.
 */
#ifndef IMPL_SORT_SORT_SORT_COMMON_IMPL_H
#define IMPL_SORT_SORT_SORT_COMMON_IMPL_H
#if defined(ASCENDC_CPU_DEBUG) && ASCENDC_CPU_DEBUG == 1
#include "kernel_log.h"
#endif

#include "kernel_tensor.h"
#include "kernel_utils.h"
#include "kernel_pop_stack_buffer.h"
#include "kernel_tiling/kernel_tiling.h"

#if __CCE_AICORE__ == 220
#include "sort_v220_impl.h"
#elif __CCE_AICORE__ == 200
#include "sort_v200_impl.h"
#elif __CCE_AICORE__ == 100
#include "sort_v100_impl.h"
#elif __CCE_AICORE__ == 300
#include "sort_v300_impl.h"
#endif

namespace AscendC {

constexpr uint8_t SORT_GATHER_MASK_EVEN = 1;
constexpr uint8_t SORT_GATHER_MASK_ODD = 2;
constexpr int32_t SORT_REGION_LABEL_POSITION = 5;
constexpr int32_t SORT_REGION_Y1_POSITION = 1;

#if ASCENDC_CPU_DEBUG
/*
 * Debug-only (CPU model) validation for a single tensor used by Sort.
 *
 * Checks two preconditions and aborts via ASCENDC_ASSERT on violation:
 *   1. The tensor's logical position is one of VECIN / VECOUT / VECCALC
 *      (the positions the vector unit can address).
 *   2. The tensor's address, relative to the VECCALC base of the current
 *      TPipe, is 32-byte aligned (ONE_BLK_SIZE).
 *
 * \param checkLocal tensor to validate
 * \param tensorInfo human-readable tensor name used in error messages
 */
template <typename T>
__aicore__ inline void CheckSortImpl(const LocalTensor<T> &checkLocal, const std::string& tensorInfo)
{
    ASCENDC_ASSERT(((TPosition)checkLocal.GetPosition() == TPosition::VECIN ||
        (TPosition)checkLocal.GetPosition() == TPosition::VECOUT ||
        (TPosition)checkLocal.GetPosition() == TPosition::VECCALC), {
        KERNEL_LOG(KERNEL_ERROR,
            "Failed to check %s tensor position in Sort, support positions are VECIN, VECOUT, VECCALC.",
                tensorInfo.c_str());
    });
    auto checkLocalCpuPtr = checkLocal.GetPhyAddr();
    // Address is made relative to the VECCALC base so the alignment check is on
    // the offset within unified buffer space, not the host pointer value.
    uint64_t realAddr = (uint64_t)checkLocalCpuPtr -
        (uint64_t)(GetTPipePtr()->GetBaseAddr(static_cast<int8_t>(TPosition::VECCALC)));
    // Fix: error message previously misspelled "address" as "adddress".
    ASCENDC_ASSERT((realAddr % ONE_BLK_SIZE == 0), {
        KERNEL_LOG(KERNEL_ERROR,
            "Failed to check %s tensor address alignment in Sort, current tensor address is %lu, "
            "which should be 32 byte aligned.", tensorInfo.c_str(), realAddr);
    });
}

/*
 * Debug-only validation fan-out for all tensors passed to SortImpl:
 * each tensor must sit in a vector-accessible position (VECIN/VECOUT/VECCALC)
 * and be 32-byte aligned (see CheckSortImpl).
 *
 * NOTE(review): isFullSort and repeatTimes are currently unused here; the
 * signature mirrors SortImpl's so the call site can forward all arguments.
 */
template <typename T, bool isFullSort>
__aicore__ inline void CheckSort(const LocalTensor<T> &dstLocal, const LocalTensor<T> &concatLocal,
    const LocalTensor<uint32_t> &indexLocal, LocalTensor<T> &tmpLocal, const int32_t repeatTimes)
{
    CheckSortImpl<T>(dstLocal, "dstLocal");
    CheckSortImpl<T>(concatLocal, "concatLocal");
    CheckSortImpl<uint32_t>(indexLocal, "indexLocal");
    CheckSortImpl<T>(tmpLocal, "tmpLocal");
}
#endif

/*
 * Common Sort entry point, dispatching on chip version via __CCE_AICORE__.
 *
 * \tparam T           element type; only half and float are accepted
 * \tparam isFullSort  when true, performs an additional full-sort pass
 *                     (DoFullSort) after the per-repeat sort
 * \param dstLocal     destination tensor for sorted output
 * \param concatLocal  tensor holding the data to sort (proposal-format
 *                     scratch on <=200 chips)
 * \param indexLocal   optional index tensor; on <=200 chips a nonzero size
 *                     means indices are merged into the proposal structure
 * \param tmpLocal     temporary buffer, used by DoFullSort
 * \param repeatTimes  number of hardware repeats, valid range 0 ~ 255
 */
template <typename T, bool isFullSort>
__aicore__ inline void SortImpl(const LocalTensor<T> &dstLocal, const LocalTensor<T> &concatLocal,
    const LocalTensor<uint32_t> &indexLocal, LocalTensor<T> &tmpLocal, const int32_t repeatTimes)
{
    ASCENDC_ASSERT((SupportType<T, half, float>()), {KERNEL_LOG(KERNEL_ERROR, "Failed to check dtype in Sort, current "
        "api support dtype combination is src and dst both: half / float");});
    ASCENDC_ASSERT((repeatTimes <= 255 && repeatTimes >= 0), { KERNEL_LOG(KERNEL_ERROR,
        "Failed to check repeatTimes value in Sort, its valid range is 0 ~ 255, current value is %d. ", repeatTimes);
    });
#if ASCENDC_CPU_DEBUG
    // CPU model only: validate tensor positions and 32-byte alignment.
    CheckSort<T, isFullSort>(dstLocal, concatLocal, indexLocal, tmpLocal, repeatTimes);
#endif
#if __CCE_AICORE__ >= 220
    // Newer chips: hardware Sort32 sorts 32 elements per repeat and carries
    // the index tensor natively.
    Sort32(dstLocal, concatLocal, indexLocal, repeatTimes);
#elif __CCE_AICORE__ <= 200
    // Older chips: RpSort16 works on a proposal structure; when indices are
    // supplied (size != 0) they must first be merged into the proposal fields.
    if (indexLocal.GetSize() != 0) {
        if constexpr (IsSameType<T, half>::value) {
            // For half, each uint32_t index is split into two 16-bit halves by
            // reinterpreting indexLocal as half and gathering even/odd elements,
            // then each half is written into a proposal field (positions 1 and 5).
            // NOTE(review): which half lands in Y1 vs label follows the chip's
            // proposal layout — confirm against the version-specific impl.
            uint64_t rsvdCnt = 0;
            // sort process 16-elem each repeat, while gatherMask process 64-elem(uint32_t) each repeat
            // repeat time for gather mask is 1/4 of sort's repeat time
            // align repeat time to 64-elem
            constexpr uint16_t SORT_ELEM_PER_REPEAT = 16;
            constexpr uint16_t GATHER_ELEM_PER_REPEAT = 64;
            const uint16_t gatherRepTimes = (repeatTimes * SORT_ELEM_PER_REPEAT + GATHER_ELEM_PER_REPEAT - 1) /
                GATHER_ELEM_PER_REPEAT;

            // dstLocal is reused as scratch for the gathered halves here; it is
            // fully overwritten by RpSort16 below.
            LocalTensor<T> indexTensor = indexLocal.ReinterpretCast<T>();
            GatherMask(dstLocal, indexTensor, SORT_GATHER_MASK_EVEN, false,
                (uint32_t)0, {1, gatherRepTimes, DEFAULT_REPEAT_STRIDE, 0}, rsvdCnt);
            PipeBarrier<PIPE_V>();
            ProposalConcat(concatLocal, dstLocal, repeatTimes, SORT_REGION_Y1_POSITION);
            PipeBarrier<PIPE_V>();
            GatherMask(dstLocal, indexTensor, SORT_GATHER_MASK_ODD, false,
                (uint32_t)0, {1, gatherRepTimes, DEFAULT_REPEAT_STRIDE, 0}, rsvdCnt);
            PipeBarrier<PIPE_V>();
            ProposalConcat(concatLocal, dstLocal, repeatTimes, SORT_REGION_LABEL_POSITION);
        } else {
            // For float, a uint32_t index fits one element: reinterpret and
            // concat directly into the label field (position 5).
            ProposalConcat(concatLocal, indexLocal.ReinterpretCast<T>(), (uint16_t)repeatTimes,
                           SORT_REGION_LABEL_POSITION);
        }
        PipeBarrier<PIPE_V>();
    }
    // Sort 16 proposals per repeat.
    RpSort16(dstLocal, concatLocal, repeatTimes);
#endif
    if constexpr (isFullSort) {
        // Merge the per-repeat sorted runs into one fully sorted sequence.
        PipeBarrier<PIPE_V>();
        DoFullSort(dstLocal, concatLocal, indexLocal, tmpLocal, repeatTimes);
    }
}

}  // namespace AscendC

#endif  // IMPL_SORT_SORT_SORT_COMMON_IMPL_H